import os

# Set before the framework/project imports further down are executed.
# NOTE(review): per the CUDA documentation, CUDA_LAUNCH_BLOCKING=1 makes kernel
# launches synchronous, which eases debugging at the cost of speed -- confirm
# this is still wanted for regular (non-debug) runs.
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
import sys

# Extend the import path with the current working directory, with every
# '/interface' substring removed from it. Presumably this lets the project
# packages imported below resolve when the script is launched from an
# 'interface' sub-directory -- verify against how the script is started.
cwd = os.getcwd()
sys.path.append(cwd.replace('/interface', ''))
# Echo the resulting search path (debugging aid).
print(sys.path)

import numpy as np
from player_ranking.player_evaluation_metric import PlayerRanking, run_risk_sensitive_player_evaluation
from generic.model_util import get_distrib_q_model_save_path, get_maf_save_path
from agent import SportsAgent
from generic.data_util import load_config, read_args, ICEHOCKEY_ACTIONS, divide_dataset_according2date, read_player_info


def _collect_player_rows(rpr, alpha, uncertainty_threshold, player_id_name, print_measures):
    """Build the sorted ranking rows for one (alpha, uncertainty threshold) pair.

    Args:
        rpr: Evaluation result exposing ``player_impact_dict_by_alpha``,
            ``player_uncertainty_dict_by_alpha``, ``player_stats_dict`` and
            ``interested_measures``.
        alpha: Key selecting the per-player impact/uncertainty dictionaries.
        uncertainty_threshold: Only impacts whose paired uncertainty is
            ``<=`` this value are summed.
        player_id_name: Mapping ``pid -> sequence``; element 0 is used as the
            row key and the remaining elements are prepended to the stats.
        print_measures: Names looked up in ``rpr.interested_measures`` to
            pick the statistics that are printed.

    Returns:
        A list of ``(name, fields)`` tuples sorted by the last field (the
        summed impact) in descending order.
    """
    impacts_by_player = rpr.player_impact_dict_by_alpha[alpha]
    uncertainties_by_player = rpr.player_uncertainty_dict_by_alpha[alpha]
    # The column positions do not depend on the player, so resolve them once.
    measure_columns = [rpr.interested_measures.index(measure)
                       for measure in print_measures]

    rows_by_name = {}
    for pid, impacts in impacts_by_player.items():
        # Players without statistics are never printed; skip them before
        # doing any filtering work.
        if pid not in rpr.player_stats_dict:
            continue
        uncertainties = uncertainties_by_player[pid]
        # Index-based pairing is kept on purpose: a shorter uncertainty list
        # still raises instead of being silently truncated.
        kept_impacts = [impact for idx, impact in enumerate(impacts)
                        if uncertainties[idx] <= uncertainty_threshold]
        player_stats = rpr.player_stats_dict[pid]
        fields = [int(player_stats[column]) for column in measure_columns]
        fields.append(round(np.sum(kept_impacts), 2))
        # NOTE(review): rows are keyed by player_id_name[pid][0], so two ids
        # sharing that value overwrite each other (unchanged from before).
        rows_by_name[player_id_name[pid][0]] = player_id_name[pid][1:] + fields
    return sorted(rows_by_name.items(), key=lambda item: item[1][-1], reverse=True)


def _write_ranking_table(file_path, ranked_rows, print_measures):
    """Write the header and one ' & '-separated, '\\\\'-terminated line per row."""
    header = 'player name  & position & team ' \
             + ''.join(' & ' + measure for measure in print_measures) \
             + ' & RGIM ' + '\\\\ \n'
    lines = [header]
    for name, fields in ranked_rows:
        lines.append(str(name)
                     + ''.join(' & ' + str(field) for field in fields)
                     + '\\\\ \n')
    # Explicit UTF-8 so names with non-ASCII characters do not depend on the
    # platform's default encoding.
    with open(file_path, 'w', encoding='utf-8') as rank_file:
        rank_file.writelines(lines)


def test(args):
    """Load the saved DQN and MAF models and export per-player ranking tables.

    The configuration returned by ``load_config(args)`` is overridden with the
    hard-coded evaluation settings below, both checkpoints are loaded into a
    ``SportsAgent``, ``run_risk_sensitive_player_evaluation`` is run, and one
    text file is written for every combination of
    ``agent.uncertainty_thresholds`` and alpha in 0.1 ... 0.9.

    Args:
        args: Passed unchanged to ``load_config``.

    Returns:
        None. Results are written below ``./player_ranking/``.
    """
    # Hard-coded evaluation settings (these override the loaded config).
    mode = 'all'
    test_num_tau = 64
    test_num_supp = 64
    test_gamma = 1
    test_train_rate = 0.8
    max_trace_length = 3
    apply_dynamic_trace_length = False
    test_apply_rnn = True
    test_apply_resnet = True
    test_cut_at_goal = True
    condition_on_action = False
    # Checkpoint identifiers of the models to load.
    dqn_save_date = 'Nov-19-2021'
    maf_save_date = 'Jan-05-2022'
    load_dqn_episode = 19000
    load_maf_episode = 6000
    best_label = ''
    sanity_check_msg = None
    iteration = 'tmp'
    print_measures = ['P', 'A', 'G']
    alphas = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]

    config, debug_mode, log_file_path = load_config(args)
    debug_msg = 'debug_' if debug_mode else ''

    config['general']['model']['apply_rnn'] = test_apply_rnn
    config['general']['model']['apply_resnet'] = test_apply_resnet
    config['general']['model']['num_tau'] = test_num_tau
    config['general']['model']['num_supp'] = test_num_supp
    config['general']['training']['gamma'] = test_gamma
    config['general']['training']['cut_at_goal'] = test_cut_at_goal
    config['general']['training']['train_rate'] = test_train_rate
    config['general']['model']['apply_dynamic_trace_length'] = apply_dynamic_trace_length
    config['general']['model']['max_trace_length'] = max_trace_length
    config['general']['maf']['condition_on_action'] = condition_on_action

    log_file = open(log_file_path, 'w') if log_file_path is not None else None
    try:
        player_id_name = read_player_info()
        agent = SportsAgent(config=config, log_file=log_file)

        # read dqn model
        # The checkpoint directory is always resolved with an empty debug_msg,
        # even when running in debug mode.
        dqn_save_dir = get_distrib_q_model_save_path(agent=agent,
                                                     date_label=dqn_save_date,
                                                     debug_msg='')
        dqn_save_dir += best_label
        agent.load_pretrained_model(
            load_from=dqn_save_dir + '/saved_model_{0}'.format(load_dqn_episode),
            load_optim=False,
            log_file=log_file)

        # read maf model
        maf_save_dir = get_maf_save_path(agent=agent,
                                         date_label=maf_save_date,
                                         debug_msg='')
        agent.load_maf(
            load_from=maf_save_dir + '/saved_model_{0}'.format(load_maf_episode),
            log_file=log_file)

        # The output directory name reuses whatever follows the last 'saved'
        # in the MAF checkpoint directory.
        ranking_player_saved_dir = './player_ranking/{0}rank{1}-mode_{2}-epi_{3}/'.format(
            debug_msg, maf_save_dir.split('saved')[-1], mode, load_maf_episode)
        # makedirs also creates missing parents and tolerates an existing
        # directory, unlike the former exists-check followed by mkdir.
        os.makedirs(ranking_player_saved_dir, exist_ok=True)

        rpr = run_risk_sensitive_player_evaluation(agent=agent,
                                                   model_save_path=maf_save_dir,
                                                   iteration=iteration,
                                                   uncertainty_model='maf',
                                                   mode=mode,
                                                   log_file=log_file,
                                                   sanity_check_msg=sanity_check_msg,
                                                   debug_mode=debug_mode,
                                                   debug_msg=debug_msg,
                                                   compute_correlations=False)

        for uncertainty_threshold in agent.uncertainty_thresholds:
            for alpha in alphas:
                ranked_rows = _collect_player_rows(rpr, alpha, uncertainty_threshold,
                                                   player_id_name, print_measures)
                file_name = 'ranking_alpha_{0}_uncertain_{1}.txt'.format(
                    alpha, uncertainty_threshold)
                _write_ranking_table(os.path.join(ranking_player_saved_dir, file_name),
                                     ranked_rows, print_measures)
    finally:
        # Release the log file handle even when loading or evaluation fails.
        if log_file is not None:
            log_file.close()


if __name__ == "__main__":
    # Script entry point: hand the result of read_args() straight to test().
    test(read_args())
